{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# 多头注意力\n",
    ":label:`sec_multihead-attention`\n",
    "\n",
    "\n",
    "在实践中，当给定相同的查询、键和值的集合时，我们希望模型可以基于相同的注意力机制学习到不同的行为，然后将不同的行为作为知识组合起来，例如捕获序列内各种范围的依赖关系（例如，短距离依赖和长距离依赖）。因此，允许注意力机制组合使用查询、键和值的不同 子空间表示（representation subspaces）可能是有益的。\n",
    "\n",
    "为此，与使用单独一个注意力汇聚不同，我们可以用独立学习得到的 $h$ 组不同的 线性投影（linear projections）来变换查询、键和值。然后，这 $h$ 组变换后的查询、键和值将并行地送到注意力汇聚中。最后，将这 $h$ 个注意力汇聚的输出拼接在一起，并且通过另一个可以学习的线性投影进行变换，以产生最终输出。这种设计被称为 多头注意力，其中 $h$ 个注意力汇聚输出中的每一个输出都被称作一个 *头（head）* :cite:`Vaswani.Shazeer.Parmar.ea.2017`。 图 :numref:`fig_multi-head-attention` 展示了使用全连接层来实现可学习的线性变换的多头注意力。\n",
    "\n",
    "\n",
    "![多头注意力，多个头连结然后线性变换。](https://zh-v2.d2l.ai/_images/multi-head-attention.svg)\n",
    ":label:`fig_multi-head-attention`\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "## 模型\n",
    "\n",
    "在实现多头注意力之前，让我们用数学语言将这个模型形式化地描述出来。给定查询 $\\mathbf{q} \\in \\mathbb{R}^{d_q}$、键 $\\mathbf{k} \\in \\mathbb{R}^{d_k}$ 和值 $\\mathbf{v} \\in \\mathbb{R}^{d_v}$，每个注意力头 $\\mathbf{h}_i$  ($i = 1, \\ldots, h$) 的计算方法为：\n",
    "\n",
    "$$\\mathbf{h}_i = f(\\mathbf W_i^{(q)}\\mathbf q, \\mathbf W_i^{(k)}\\mathbf k,\\mathbf W_i^{(v)}\\mathbf v) \\in \\mathbb R^{p_v},$$\n",
    "\n",
    "其中，可学习的参数包括 $\\mathbf W_i^{(q)}\\in\\mathbb R^{p_q\\times d_q}$、$\\mathbf W_i^{(k)}\\in\\mathbb R^{p_k\\times d_k}$ 和 $\\mathbf W_i^{(v)}\\in\\mathbb R^{p_v\\times d_v}$，以及注意力汇聚的函数 $f$。$f$ 可以是 :numref:`sec_attention-scoring-functions` 中的加性注意力和缩放点积注意力。多头注意力的输出需要经过另一个线性转换，它对应着 $h$ 个头连结后的结果，因此其可学习参数是 $\\mathbf W_o\\in\\mathbb R^{p_o\\times h p_v}$：\n",
    "\n",
    "$$\\mathbf W_o \\begin{bmatrix}\\mathbf h_1\\\\\\vdots\\\\\\mathbf h_h\\end{bmatrix} \\in \\mathbb{R}^{p_o}.$$\n",
    "\n",
    "基于这种设计，每个头都可能会关注输入的不同部分。可以表示比简单加权平均值更复杂的函数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/PlotUtils.java\n",
    "\n",
    "%load ../utils/attention/Chap10Utils.java\n",
    "%load ../utils/attention/DotProductAttention.java\n",
    "%load ../utils/attention/MultiHeadAttention.java\n",
    "%load ../utils/attention/PositionalEncoding.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "为了允许并行计算多个头，下面的`MultiHeadAttention` 类使用两个如下定义的转置函数。具体来说，`transposeOutput` 函数颠倒了 `transposeQkv` 函数的操作。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static NDArray transposeQkv(NDArray X, int numHeads) {\n",
    "    // Shape of input `X`:\n",
    "    // (`batchSize`, no. of queries or key-value pairs, `numHiddens`).\n",
    "    // Shape of output `X`:\n",
    "    // (`batchSize`, no. of queries or key-value pairs, `numHeads`,\n",
    "    // `numHiddens` / `numHeads`)\n",
    "    X = X.reshape(X.getShape().get(0), X.getShape().get(1), numHeads, -1);\n",
    "\n",
    "    // Shape of output `X`:\n",
    "    // (`batchSize`, `numHeads`, no. of queries or key-value pairs,\n",
    "    // `numHiddens` / `numHeads`)\n",
    "    X = X.transpose(0, 2, 1, 3);\n",
    "\n",
    "    // Shape of `output`:\n",
    "    // (`batchSize` * `numHeads`, no. of queries or key-value pairs,\n",
    "    // `numHiddens` / `numHeads`)\n",
    "    return X.reshape(-1, X.getShape().get(2), X.getShape().get(3));\n",
    "}\n",
    "\n",
    "public static NDArray transposeOutput(NDArray X, int numHeads) {\n",
    "    X = X.reshape(-1, numHeads, X.getShape().get(1), X.getShape().get(2));\n",
    "    X = X.transpose(0, 2, 1, 3);\n",
    "    return X.reshape(X.getShape().get(0), X.getShape().get(1), -1);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 3
   },
   "source": [
    "## 实现\n",
    "\n",
    "在实现过程中，我们选择缩放点积注意力作为每一个注意力头。为了避免计算成本和参数数量的大幅增长，我们设定 $p_q = p_k = p_v = p_o / h$。值得注意的是，如果我们将查询、键和值的线性变换的输出数量设置为 $p_q h = p_k h = p_v h = p_o$，则可以并行计算 $h$ 个头。在下面的实现中，$p_o$ 是通过参数 `numHiddens` 指定的。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static class MultiHeadAttention extends AbstractBlock {\n",
    "\n",
    "    private int numHeads;\n",
    "    public DotProductAttention attention;\n",
    "    private Linear W_k;\n",
    "    private Linear W_q;\n",
    "    private Linear W_v;\n",
    "    private Linear W_o;\n",
    "\n",
    "    public MultiHeadAttention(int numHiddens, int numHeads, float dropout, boolean useBias) {\n",
    "        this.numHeads = numHeads;\n",
    "\n",
    "        attention = new DotProductAttention(dropout);\n",
    "\n",
    "        W_q = Linear.builder().setUnits(numHiddens).optBias(useBias).build();\n",
    "        addChildBlock(\"W_q\", W_q);\n",
    "\n",
    "        W_k = Linear.builder().setUnits(numHiddens).optBias(useBias).build();\n",
    "        addChildBlock(\"W_k\", W_k);\n",
    "\n",
    "        W_v = Linear.builder().setUnits(numHiddens).optBias(useBias).build();\n",
    "        addChildBlock(\"W_v\", W_v);\n",
    "\n",
    "        W_o = Linear.builder().setUnits(numHiddens).optBias(useBias).build();\n",
    "        addChildBlock(\"W_o\", W_o);\n",
    "\n",
    "        Dropout dropout1 = Dropout.builder().optRate(dropout).build();\n",
    "        addChildBlock(\"dropout\", dropout1);\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected NDList forwardInternal(\n",
    "            ParameterStore ps,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params) {\n",
    "        // Shape of `queries`, `keys`, or `values`:\n",
    "        // (`batchSize`, no. of queries or key-value pairs, `numHiddens`)\n",
    "        // Shape of `validLens`:\n",
    "        // (`batchSize`,) or (`batchSize`, no. of queries)\n",
    "        // After transposing, shape of output `queries`, `keys`, or `values`:\n",
    "        // (`batchSize` * `numHeads`, no. of queries or key-value pairs,\n",
    "        // `numHiddens` / `numHeads`)\n",
    "        NDArray queries = inputs.get(0);\n",
    "        NDArray keys = inputs.get(1);\n",
    "        NDArray values = inputs.get(2);\n",
    "        NDArray validLens = inputs.get(3);\n",
    "        // On axis 0, copy the first item (scalar or vector) for\n",
    "        // `numHeads` times, then copy the next item, and so on\n",
    "        validLens = validLens.repeat(0, numHeads);\n",
    "\n",
    "        queries =\n",
    "                transposeQkv(\n",
    "                        W_q.forward(ps, new NDList(queries), training, params).get(0),\n",
    "                        numHeads);\n",
    "        keys =\n",
    "                transposeQkv(\n",
    "                        W_k.forward(ps, new NDList(keys), training, params).get(0), numHeads);\n",
    "        values =\n",
    "                transposeQkv(\n",
    "                        W_v.forward(ps, new NDList(values), training, params).get(0), numHeads);\n",
    "\n",
    "        // Shape of `output`: (`batchSize` * `numHeads`, no. of queries,\n",
    "        // `numHiddens` / `numHeads`)\n",
    "        NDArray output =\n",
    "                attention\n",
    "                        .forward(\n",
    "                                ps,\n",
    "                                new NDList(queries, keys, values, validLens),\n",
    "                                training,\n",
    "                                params)\n",
    "                        .get(0);\n",
    "\n",
    "        // Shape of `outputConcat`:\n",
    "        // (`batchSize`, no. of queries, `numHiddens`)\n",
    "        NDArray outputConcat = transposeOutput(output, numHeads);\n",
    "        return new NDList(W_o.forward(ps, new NDList(outputConcat), training, params).get(0));\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputShapes) {\n",
    "        throw new UnsupportedOperationException(\"Not implemented\");\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public void initializeChildBlocks(\n",
    "            NDManager manager, DataType dataType, Shape... inputShapes) {\n",
    "        try (NDManager sub = manager.newSubManager()) {\n",
    "            NDArray queries = sub.zeros(inputShapes[0], dataType);\n",
    "            NDArray keys = sub.zeros(inputShapes[1], dataType);\n",
    "            NDArray values = sub.zeros(inputShapes[2], dataType);\n",
    "            NDArray validLens = sub.zeros(inputShapes[3], dataType);\n",
    "            validLens = validLens.repeat(0, numHeads);\n",
    "\n",
    "            ParameterStore ps = new ParameterStore(sub, false);\n",
    "\n",
    "            W_q.initialize(manager, dataType, queries.getShape());\n",
    "            W_k.initialize(manager, dataType, keys.getShape());\n",
    "            W_v.initialize(manager, dataType, values.getShape());\n",
    "\n",
    "            queries =\n",
    "                    transposeQkv(W_q.forward(ps, new NDList(queries), false).get(0), numHeads);\n",
    "            keys = transposeQkv(W_k.forward(ps, new NDList(keys), false).get(0), numHeads);\n",
    "            values = transposeQkv(W_v.forward(ps, new NDList(values), false).get(0), numHeads);\n",
    "\n",
    "            NDList list = new NDList(queries, keys, values, validLens);\n",
    "            attention.initialize(sub, dataType, list.getShapes());\n",
    "            NDArray output = attention.forward(ps, list, false).head();\n",
    "            NDArray outputConcat = Chap10Utils.transposeOutput(output, numHeads);\n",
    "\n",
    "            W_o.initialize(manager, dataType, outputConcat.getShape());\n",
    "        }\n",
    "    }\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 9
   },
   "source": [
    "让我们使用键和值相同的小例子来测试我们编写的 `MultiHeadAttention` 类。多头注意力输出的形状是 (`batchSize`, `numQueries`, `numHiddens`)。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int numHiddens = 100;\n",
    "int numHeads = 5;\n",
    "MultiHeadAttention attention = new MultiHeadAttention(numHiddens, numHeads, 0.5f, false);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int batchSize = 2;\n",
    "int numQueries = 4;\n",
    "int numKvpairs = 6;\n",
    "NDArray validLens = manager.create(new float[] {3, 2});\n",
    "NDArray X = manager.ones(new Shape(batchSize, numQueries, numHiddens));\n",
    "NDArray Y = manager.ones(new Shape(batchSize, numKvpairs, numHiddens));\n",
    "\n",
    "ParameterStore ps = new ParameterStore(manager, false);\n",
    "NDList input = new NDList(X, Y, Y, validLens);\n",
    "attention.initialize(manager, DataType.FLOAT32, input.getShapes());\n",
    "NDList result = attention.forward(ps, input, false);\n",
    "result.get(0).getShape();\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 13
   },
   "source": [
    "## 小结\n",
    "\n",
    "* 多头注意力融合了来自于相同的注意力汇聚产生的不同的知识，这些知识的不同来源于相同的查询、键和值的不同的子空间表示。\n",
    "* 基于适当的张量操作，可以实现多头注意力的并行计算。\n",
    "\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 分别可视化这个实验中的多个头的注意力权重。\n",
    "1. 假设我们已经拥有一个完成训练的基于多头注意力的模型，现在希望修剪最不重要的注意力头以提高预测速度。应该如何设计实验来衡量注意力头的重要性？\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
